17. 参数微调

把权重和偏置项加载到新模型中

很多时候你想调整,或者说“微调”一个你已经训练并保存了的模型。但是,把保存的变量直接加载到已经修改过的模型会产生错误。让我们看看如何解决这个问题。

命名报错

TensorFlow 对 Tensor 和计算使用一个叫 name 的字符串辨识器,如果没有定义 name,TensorFlow 会自动创建一个。TensorFlow 会把第一个节点命名为 <Type>,把后续的命名为<Type>_<number>。让我们看看这对加载一个有不同顺序权重和偏置项的模型有哪些影响:

import tensorflow as tf

# Remove the previous weights and bias
# 移除先前的权重和偏置项
tf.reset_default_graph()

save_file = 'model.ckpt'

# Two Tensor Variables: weights and bias
# 两个 Tensor 变量:权重和偏置项
weights = tf.Variable(tf.truncated_normal([2, 3]))
bias = tf.Variable(tf.truncated_normal([3]))

saver = tf.train.Saver()

# Print the name of Weights and Bias
# 打印权重和偏置项的名字
print('Save Weights: {}'.format(weights.name))
print('Save Bias: {}'.format(bias.name))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

# Remove the previous weights and bias
# 移除之前的权重和偏置项
tf.reset_default_graph()

# Two Variables: weights and bias
# 两个变量:权重和偏置项
bias = tf.Variable(tf.truncated_normal([3]))
weights = tf.Variable(tf.truncated_normal([2, 3]))

saver = tf.train.Saver()

# Print the name of Weights and Bias
# 打印权重和偏置项的名字
print('Load Weights: {}'.format(weights.name))
print('Load Bias: {}'.format(bias.name))

with tf.Session() as sess:
    # Load the weights and bias - ERROR
    # 加载权重和偏置项 - 报错
    saver.restore(sess, save_file)

上述代码会有下列输出:

Save Weights: Variable:0

Save Bias: Variable_1:0

Load Weights: Variable_1:0

Load Bias: Variable:0

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match.

你注意到,weightsbiasname 属性与你保存的模型不同。这是为什么代码报“Assign requires shapes of both tensors to match”这个错误。saver.restore(sess, save_file) 代码试图把权重数据加载到bias里,把偏置项数据加载到 weights里。

与其让 TensorFlow 来设定 name 属性,不如让我们来手动设定:

import tensorflow as tf

tf.reset_default_graph()

save_file = 'model.ckpt'

# Two Tensor Variables: weights and bias
# 两个 Tensor 变量:权重和偏置项
weights = tf.Variable(tf.truncated_normal([2, 3]), name='weights_0')
bias = tf.Variable(tf.truncated_normal([3]), name='bias_0')

saver = tf.train.Saver()

# Print the name of Weights and Bias
# 打印权重和偏置项的名称
print('Save Weights: {}'.format(weights.name))
print('Save Bias: {}'.format(bias.name))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

# Remove the previous weights and bias
# 移除之前的权重和偏置项
tf.reset_default_graph()

# Two Variables: weights and bias
# 两个变量:权重和偏置项
bias = tf.Variable(tf.truncated_normal([3]), name='bias_0')
weights = tf.Variable(tf.truncated_normal([2, 3]) ,name='weights_0')

saver = tf.train.Saver()

# Print the name of Weights and Bias
# 打印权重和偏置项的名称
print('Load Weights: {}'.format(weights.name))
print('Load Bias: {}'.format(bias.name))

with tf.Session() as sess:
    # Load the weights and bias - No Error
    # 加载权重和偏置项 - 没有报错
    saver.restore(sess, save_file)

print('Loaded Weights and Bias successfully.')

Save Weights: weights_0:0

Save Bias: bias_0:0

Load Weights: weights_0:0

Load Bias: bias_0:0

Loaded Weights and Bias successfully.

这次没问题!Tensor 名称匹配正确,数据被正确加载。